optimize threading of mha #20088

Merged: 4 commits merged into main from yufeng/mha_opt on Apr 2, 2024
Conversation

yufenglee (Member) commented Mar 26, 2024

Description

The cost computation of ComputeVxAttentionScore is wrong. It should be sequence_length * v_head_size * total_sequence_length instead of sequence_length * v_head_size * sequence_length.

The PR also fine-tuned the cost computation.

On my local box with an i9 CPU, the performance is the same as the unfused version, but it is much faster on an Azure VM with 16 threads.
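For illustration, here is a minimal standalone sketch of the corrected per-unit cost (not the exact PR diff; the local `TensorOpCost` struct only mirrors the shape of ONNX Runtime's cost type, and the parallel unit is assumed to be one (batch, head) slice):

```cpp
#include <cstddef>

// Sketch only: mirrors the shape of ORT's TensorOpCost (bytes loaded/stored
// and compute cycles per parallel unit); not the actual header.
struct TensorOpCost {
  double bytes_loaded;
  double bytes_stored;
  double compute_cycles;
};

// Cost of ComputeVxAttentionScore for one (batch, head) slice:
// probs [sequence_length x total_sequence_length] times
// V     [total_sequence_length x v_head_size].
TensorOpCost VxScoreCost(std::ptrdiff_t sequence_length,
                         std::ptrdiff_t total_sequence_length,
                         std::ptrdiff_t v_head_size,
                         std::ptrdiff_t element_size) {
  TensorOpCost cost;
  // Read the probability matrix and the concatenated (past + present) V.
  cost.bytes_loaded = static_cast<double>(
      (sequence_length * total_sequence_length +
       total_sequence_length * v_head_size) * element_size);
  // Write the [sequence_length x v_head_size] output slice.
  cost.bytes_stored =
      static_cast<double>(sequence_length * v_head_size * element_size);
  // The fix: multiply-accumulates scale with total_sequence_length,
  // i.e. sequence_length * v_head_size * total_sequence_length,
  // not sequence_length * v_head_size * sequence_length.
  cost.compute_cycles = static_cast<double>(
      sequence_length * v_head_size * total_sequence_length);
  return cost;
}
```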

Motivation and Context

#19924

BowenBao (Contributor)

Thanks @yufenglee for the PR! Do you have before and after perf numbers for the repro in issue #19924?

yufenglee (Member, Author)

> Thanks @yufenglee for the PR! Do you have before and after perf numbers for the repro in issue #19924?

On my local box, it takes ~1.5 ms before and ~0.5 ms after.

tianleiwu (Contributor) commented Mar 26, 2024

We may need to take a look at the cost-model approach and see why the cost model cannot work properly, since it is fundamental for the CPU EP. Maybe try using the correct cost (like adding the concat KV cost, etc.) to see whether that could resolve the issue as well.
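What that suggestion might look like, reusing the `TensorOpCost` sketch above (a hypothetical helper, not code from this PR): fold the work of concatenating past and present K/V into the per-unit cost so the thread pool sees the full amount of work.

```cpp
// Hypothetical: add the bytes moved when past and present K/V are copied
// into the concatenated buffer (both K and V, hence the factor of 2).
TensorOpCost WithConcatKvCost(TensorOpCost cost,
                              std::ptrdiff_t total_sequence_length,
                              std::ptrdiff_t head_size,
                              std::ptrdiff_t element_size) {
  const double concat_bytes = static_cast<double>(
      2 * total_sequence_length * head_size * element_size);
  cost.bytes_loaded += concat_bytes;  // read past + present K/V
  cost.bytes_stored += concat_bytes;  // write the concatenated buffer
  return cost;
}
```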

tianleiwu previously approved these changes Mar 26, 2024
tianleiwu (Contributor) commented Mar 26, 2024

Please run benchmarks comparing before/after (a sketch of sweeping the thread count follows this list):

  • BERT base, batch size 1/2/4, sequence lengths 16/64/128/256/512, with intra_op_num_threads = 1, 2, 4, 8
  • GPT-2 decoding, batch size 1/2/4, past sequence lengths 16/64/128/256/512, with intra_op_num_threads = 1, 2, 4, 8

Just in case: in some situations, the cost model is better.
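A minimal sketch of the sweep's thread-count knob, assuming the public Ort C++ API (model path, inputs, and timing are placeholders):

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "mha_bench");
  for (int threads : {1, 2, 4, 8}) {
    Ort::SessionOptions opts;
    opts.SetIntraOpNumThreads(threads);             // the knob swept above
    Ort::Session session(env, "model.onnx", opts);  // placeholder model path
    // ... build inputs for each batch size / sequence length in the grid,
    //     call session.Run(...) repeatedly, and record latencies ...
  }
  return 0;
}
```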

yihonglyu (Contributor)

> Description
>
> The CostModel of threading doesn't work for attention, especially for the decoding case. It leads to fewer threads being used to compute the attention. Each batch * num_of_head slice is sufficient to serve as one unit. Changed to use TrySimpleParallelFor.
>
> Motivation and Context
>
> #19924

Would changing loop_len to batch_size * num_heads_ * sequence_length solve the issue?
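Roughly what that alternative looks like (a sketch assuming ORT's internal concurrency::ThreadPool API; the attention body is elided):

```cpp
#include "core/platform/threadpool.h"  // onnxruntime::concurrency::ThreadPool

// Sketch: one parallel unit per (batch, head), split evenly across threads
// by TrySimpleParallelFor instead of letting the cost model pick the split.
void ParallelOverHeads(onnxruntime::concurrency::ThreadPool* tp,
                       std::ptrdiff_t batch_size,
                       std::ptrdiff_t num_heads) {
  onnxruntime::concurrency::ThreadPool::TrySimpleParallelFor(
      tp, batch_size * num_heads, [&](std::ptrdiff_t i) {
        const std::ptrdiff_t batch = i / num_heads;
        const std::ptrdiff_t head = i % num_heads;
        // ... compute attention for this (batch, head) slice ...
        (void)batch;
        (void)head;
      });
}
```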

yufenglee (Member, Author)

The cost computation of ComputeVxAttentionScore is wrong. It should be sequence_length * v_head_size * total_sequence_length instead of sequence_length * v_head_size * sequence_length.

Also fine-tuned the cost computation for data load and store.

yufenglee merged commit 9165498 into main on Apr 2, 2024
95 checks passed
yufenglee deleted the yufeng/mha_opt branch on April 2, 2024 at 04:32
TedThemistokleous pushed a commit to TedThemistokleous/onnxruntime that referenced this pull request May 7, 2024